{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import gc\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "# Use fully-qualified option names: since pandas 1.4 the short forms 'max_columns'/'max_rows'\n",
    "# also match styler.render.* options, so pd.set_option raises OptionError on them.\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.max_rows', 100)\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.model_selection import KFold, StratifiedKFold\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "import jieba\n",
    "\n",
    "import lightgbm as lgb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(12000, 7)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>level_1</th>\n",
       "      <th>level_2</th>\n",
       "      <th>level_3</th>\n",
       "      <th>level_4</th>\n",
       "      <th>content</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（二）电气安全</td>\n",
       "      <td>6、移动用电产品、电动工具及照明</td>\n",
       "      <td>1、移动使用的用电产品和I类电动工具的绝缘线，必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。</td>\n",
       "      <td>使用移动手动电动工具,外接线绝缘皮破损,应停止使用.</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>3、消防设施、器材和消防安全标志是否在位、完整；</td>\n",
       "      <td>一般</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>6、重点工种人员以及其他员工消防知识的掌握情况；</td>\n",
       "      <td>消防知识要加强</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>3、消防设施、器材和消防安全标志是否在位、完整；</td>\n",
       "      <td>消防通道有货物摆放 清理不及时</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>4、常闭式防火门是否处于关闭状态，防火卷帘下是否堆放物品影响使用；</td>\n",
       "      <td>防火门打开状态</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>8、易燃易爆危险物品和场所防火防爆措施的落实情况以及其他重要物资的防火安全情况；</td>\n",
       "      <td>防爆柜里面稀释剂，机油费混装</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；</td>\n",
       "      <td>已经整改</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；</td>\n",
       "      <td>逃生通道有货物阻挡。</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>2、安全疏散通道、疏散指示标志、应急照明和安全出口情况；</td>\n",
       "      <td>已整改</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（四）作业环境</td>\n",
       "      <td>1、作业通道</td>\n",
       "      <td>1、作业通道应保持畅通，禁止临时堆放货物；通道以黄色或者白色线标明。凡有地坑、壕、池的地方,...</td>\n",
       "      <td>通道黄色线标脱落，已及时重新标好线标。</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id            level_1  level_2           level_3  \\\n",
       "0   0  工业/危化品类（现场）—2016版  （二）电气安全  6、移动用电产品、电动工具及照明   \n",
       "1   1  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "2   2  工业/危化品类（现场）—2016版  （一）消防检查            2、防火检查   \n",
       "3   3  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "4   4  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "5   5  工业/危化品类（现场）—2016版  （一）消防检查            2、防火检查   \n",
       "6   6  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "7   7  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "8   8  工业/危化品类（现场）—2016版  （一）消防检查            2、防火检查   \n",
       "9   9  工业/危化品类（现场）—2016版  （四）作业环境            1、作业通道   \n",
       "\n",
       "                                             level_4  \\\n",
       "0  1、移动使用的用电产品和I类电动工具的绝缘线，必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。   \n",
       "1                           3、消防设施、器材和消防安全标志是否在位、完整；   \n",
       "2                           6、重点工种人员以及其他员工消防知识的掌握情况；   \n",
       "3                           3、消防设施、器材和消防安全标志是否在位、完整；   \n",
       "4                  4、常闭式防火门是否处于关闭状态，防火卷帘下是否堆放物品影响使用；   \n",
       "5           8、易燃易爆危险物品和场所防火防爆措施的落实情况以及其他重要物资的防火安全情况；   \n",
       "6                 2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；   \n",
       "7                 2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；   \n",
       "8                       2、安全疏散通道、疏散指示标志、应急照明和安全出口情况；   \n",
       "9  1、作业通道应保持畅通，禁止临时堆放货物；通道以黄色或者白色线标明。凡有地坑、壕、池的地方,...   \n",
       "\n",
       "                      content  label  \n",
       "0  使用移动手动电动工具,外接线绝缘皮破损,应停止使用.      0  \n",
       "1                          一般      1  \n",
       "2                     消防知识要加强      0  \n",
       "3             消防通道有货物摆放 清理不及时      0  \n",
       "4                     防火门打开状态      0  \n",
       "5              防爆柜里面稀释剂，机油费混装      0  \n",
       "6                        已经整改      1  \n",
       "7                  逃生通道有货物阻挡。      0  \n",
       "8                         已整改      1  \n",
       "9         通道黄色线标脱落，已及时重新标好线标。      0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Labelled training data (includes the `label` column).\n",
    "train = pd.read_csv('raw_data/train.csv')\n",
    "print(train.shape)\n",
    "train.head(n=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(18000, 6)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>level_1</th>\n",
       "      <th>level_2</th>\n",
       "      <th>level_3</th>\n",
       "      <th>level_4</th>\n",
       "      <th>content</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>交通运输类（现场）—2016版</td>\n",
       "      <td>（一）消防安全</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>2、安全疏散通道、疏散指示标志、应急照明和安全出口情况。</td>\n",
       "      <td>RB1洗地机占用堵塞安全通道</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>工业/危化品类（选项）—2016版</td>\n",
       "      <td>（二）仓库</td>\n",
       "      <td>1、一般要求</td>\n",
       "      <td>1、库房内储存物品应分类、分堆、限额存放。</td>\n",
       "      <td>未分类堆放</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>3、消防设施、器材和消防安全标志是否在位、完整；</td>\n",
       "      <td>消防设施、器材和消防安全标志是否在位、完整</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>商贸服务教文卫类（现场）—2016版</td>\n",
       "      <td>（二）电气安全</td>\n",
       "      <td>3、电气线路及电源插头插座</td>\n",
       "      <td>3、电源插座、电源插头应按规定正确接线。</td>\n",
       "      <td>插座随意放在电器旁边</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>商贸服务教文卫类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>6、其他消防安全情况。</td>\n",
       "      <td>检查中发现一瓶灭火器过期</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>4、灭火器材配置及有效情况；</td>\n",
       "      <td>灭火器过期更换</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>11、消防安全标志的设置情况和完好、有效情况；</td>\n",
       "      <td>仓库的墙面上 未贴严禁烟火标志， 已进行整改</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>11、消防安全标志的设置情况和完好、有效情况；</td>\n",
       "      <td>部分消防标志褪色，已更换！</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>3、消防设施、器材和消防安全标志是否在位、完整；</td>\n",
       "      <td>手推车放在灭火器前，阻挡灭火器</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>12、其他需要检查的内容。　　防火检查需填写检查记录。检查人员和被检查部门负责人在检查记录上签名。</td>\n",
       "      <td>消防栓未定时检查</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id             level_1  level_2        level_3  \\\n",
       "0   0     交通运输类（现场）—2016版  （一）消防安全         2、防火检查   \n",
       "1   1   工业/危化品类（选项）—2016版    （二）仓库         1、一般要求   \n",
       "2   2   工业/危化品类（现场）—2016版  （一）消防检查         1、防火巡查   \n",
       "3   3  商贸服务教文卫类（现场）—2016版  （二）电气安全  3、电气线路及电源插头插座   \n",
       "4   4  商贸服务教文卫类（现场）—2016版  （一）消防检查         1、防火巡查   \n",
       "5   5   工业/危化品类（现场）—2016版  （一）消防检查         2、防火检查   \n",
       "6   6   工业/危化品类（现场）—2016版  （一）消防检查         2、防火检查   \n",
       "7   7   工业/危化品类（现场）—2016版  （一）消防检查         2、防火检查   \n",
       "8   8   工业/危化品类（现场）—2016版  （一）消防检查         1、防火巡查   \n",
       "9   9   工业/危化品类（现场）—2016版  （一）消防检查         2、防火检查   \n",
       "\n",
       "                                             level_4                 content  \n",
       "0                       2、安全疏散通道、疏散指示标志、应急照明和安全出口情况。          RB1洗地机占用堵塞安全通道  \n",
       "1                              1、库房内储存物品应分类、分堆、限额存放。                   未分类堆放  \n",
       "2                           3、消防设施、器材和消防安全标志是否在位、完整；   消防设施、器材和消防安全标志是否在位、完整  \n",
       "3                               3、电源插座、电源插头应按规定正确接线。              插座随意放在电器旁边  \n",
       "4                                        6、其他消防安全情况。            检查中发现一瓶灭火器过期  \n",
       "5                                     4、灭火器材配置及有效情况；                 灭火器过期更换  \n",
       "6                            11、消防安全标志的设置情况和完好、有效情况；  仓库的墙面上 未贴严禁烟火标志， 已进行整改  \n",
       "7                            11、消防安全标志的设置情况和完好、有效情况；           部分消防标志褪色，已更换！  \n",
       "8                           3、消防设施、器材和消防安全标志是否在位、完整；         手推车放在灭火器前，阻挡灭火器  \n",
       "9  12、其他需要检查的内容。　　防火检查需填写检查记录。检查人员和被检查部门负责人在检查记录上签名。                消防栓未定时检查  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unlabelled test data (same columns as train, minus `label`).\n",
    "test = pd.read_csv('raw_data/test.csv')\n",
    "print(test.shape)\n",
    "test.head(n=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BERT-model prediction scores (RoBERTa files; OOF for train per the file name), merged on id.\n",
    "train_bert_pred = pd.read_csv('roberta_pred_oof.csv')\n",
    "test_bert_pred = pd.read_csv('roberta_pred_test.csv')\n",
    "\n",
    "# Drop any prediction columns left by an earlier run of this cell, so re-running it does not\n",
    "# produce suffixed *_x / *_y duplicates.\n",
    "train = train.drop(columns=train_bert_pred.columns.drop('id'), errors='ignore')\n",
    "test = test.drop(columns=test_bert_pred.columns.drop('id'), errors='ignore')\n",
    "\n",
    "# validate= rejects duplicate ids; the length checks catch ids with no prediction, which the\n",
    "# default inner join would otherwise drop silently.\n",
    "n_train, n_test = len(train), len(test)\n",
    "train = pd.merge(train, train_bert_pred, on='id', validate='one_to_one')\n",
    "test = pd.merge(test, test_bert_pred, on='id', validate='one_to_one')\n",
    "assert len(train) == n_train, 'some train ids have no BERT prediction'\n",
    "assert len(test) == n_test, 'some test ids have no BERT prediction'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30000, 8)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>level_1</th>\n",
       "      <th>level_2</th>\n",
       "      <th>level_3</th>\n",
       "      <th>level_4</th>\n",
       "      <th>content</th>\n",
       "      <th>label</th>\n",
       "      <th>bert_pred</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（二）电气安全</td>\n",
       "      <td>6、移动用电产品、电动工具及照明</td>\n",
       "      <td>1、移动使用的用电产品和I类电动工具的绝缘线，必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。</td>\n",
       "      <td>使用移动手动电动工具,外接线绝缘皮破损,应停止使用.</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.096279</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>3、消防设施、器材和消防安全标志是否在位、完整；</td>\n",
       "      <td>一般</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.496831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>6、重点工种人员以及其他员工消防知识的掌握情况；</td>\n",
       "      <td>消防知识要加强</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.056331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>3、消防设施、器材和消防安全标志是否在位、完整；</td>\n",
       "      <td>消防通道有货物摆放 清理不及时</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.817928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>4、常闭式防火门是否处于关闭状态，防火卷帘下是否堆放物品影响使用；</td>\n",
       "      <td>防火门打开状态</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.853796</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>8、易燃易爆危险物品和场所防火防爆措施的落实情况以及其他重要物资的防火安全情况；</td>\n",
       "      <td>防爆柜里面稀释剂，机油费混装</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.080718</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；</td>\n",
       "      <td>已经整改</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.261704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>1、防火巡查</td>\n",
       "      <td>2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；</td>\n",
       "      <td>逃生通道有货物阻挡。</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.527178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（一）消防检查</td>\n",
       "      <td>2、防火检查</td>\n",
       "      <td>2、安全疏散通道、疏散指示标志、应急照明和安全出口情况；</td>\n",
       "      <td>已整改</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.132907</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>工业/危化品类（现场）—2016版</td>\n",
       "      <td>（四）作业环境</td>\n",
       "      <td>1、作业通道</td>\n",
       "      <td>1、作业通道应保持畅通，禁止临时堆放货物；通道以黄色或者白色线标明。凡有地坑、壕、池的地方,...</td>\n",
       "      <td>通道黄色线标脱落，已及时重新标好线标。</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.748759</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id            level_1  level_2           level_3  \\\n",
       "0   0  工业/危化品类（现场）—2016版  （二）电气安全  6、移动用电产品、电动工具及照明   \n",
       "1   1  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "2   2  工业/危化品类（现场）—2016版  （一）消防检查            2、防火检查   \n",
       "3   3  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "4   4  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "5   5  工业/危化品类（现场）—2016版  （一）消防检查            2、防火检查   \n",
       "6   6  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "7   7  工业/危化品类（现场）—2016版  （一）消防检查            1、防火巡查   \n",
       "8   8  工业/危化品类（现场）—2016版  （一）消防检查            2、防火检查   \n",
       "9   9  工业/危化品类（现场）—2016版  （四）作业环境            1、作业通道   \n",
       "\n",
       "                                             level_4  \\\n",
       "0  1、移动使用的用电产品和I类电动工具的绝缘线，必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。   \n",
       "1                           3、消防设施、器材和消防安全标志是否在位、完整；   \n",
       "2                           6、重点工种人员以及其他员工消防知识的掌握情况；   \n",
       "3                           3、消防设施、器材和消防安全标志是否在位、完整；   \n",
       "4                  4、常闭式防火门是否处于关闭状态，防火卷帘下是否堆放物品影响使用；   \n",
       "5           8、易燃易爆危险物品和场所防火防爆措施的落实情况以及其他重要物资的防火安全情况；   \n",
       "6                 2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；   \n",
       "7                 2、安全出口、疏散通道是否畅通，安全疏散指示标志、应急照明是否完好；   \n",
       "8                       2、安全疏散通道、疏散指示标志、应急照明和安全出口情况；   \n",
       "9  1、作业通道应保持畅通，禁止临时堆放货物；通道以黄色或者白色线标明。凡有地坑、壕、池的地方,...   \n",
       "\n",
       "                      content  label  bert_pred  \n",
       "0  使用移动手动电动工具,外接线绝缘皮破损,应停止使用.    0.0  -3.096279  \n",
       "1                          一般    1.0   2.496831  \n",
       "2                     消防知识要加强    0.0   1.056331  \n",
       "3             消防通道有货物摆放 清理不及时    0.0  -3.817928  \n",
       "4                     防火门打开状态    0.0  -2.853796  \n",
       "5              防爆柜里面稀释剂，机油费混装    0.0  -3.080718  \n",
       "6                        已经整改    1.0   3.261704  \n",
       "7                  逃生通道有货物阻挡。    0.0  -3.527178  \n",
       "8                         已整改    1.0   3.132907  \n",
       "9         通道黄色线标脱落，已及时重新标好线标。    0.0  -3.748759  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack train and test so features are built over both; test rows get NaN in `label`.\n",
    "# NOTE: the index is not reset, so train and test row labels overlap (both start at 0).\n",
    "df = pd.concat([train, test], axis=0)\n",
    "print(df.shape)\n",
    "df.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "level_1 19\n",
      "level_2 78\n",
      "level_3 185\n",
      "level_4 379\n",
      "content 23572\n",
      "label 2\n",
      "bert_pred 25229\n"
     ]
    }
   ],
   "source": [
    "# Distinct-value count for each column except the row id.\n",
    "for name in df.columns.drop('id'):\n",
    "    print(name, df[name].nunique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b158cc55dd3d4d22b14860bf026c4d7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Per hierarchy level: keep the raw text length as a feature, then replace the text with an integer code.\n",
    "# NOTE: not idempotent - a second run would measure the lengths of the integer codes.\n",
    "for level in tqdm(['level_1', 'level_2', 'level_3', 'level_4']):\n",
    "    df[f'{level}_strlen'] = df[level].astype(str).str.len()\n",
    "    df[level] = LabelEncoder().fit_transform(df[level])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill missing content first: astype(str) alone would turn NaN into 'nan' and report length 3.\n",
    "df['content_strlen'] = df['content'].fillna('').astype(str).apply(len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>level_1</th>\n",
       "      <th>level_2</th>\n",
       "      <th>level_3</th>\n",
       "      <th>level_4</th>\n",
       "      <th>content</th>\n",
       "      <th>label</th>\n",
       "      <th>bert_pred</th>\n",
       "      <th>level_1_strlen</th>\n",
       "      <th>level_2_strlen</th>\n",
       "      <th>level_3_strlen</th>\n",
       "      <th>level_4_strlen</th>\n",
       "      <th>content_strlen</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>12</td>\n",
       "      <td>32</td>\n",
       "      <td>170</td>\n",
       "      <td>136</td>\n",
       "      <td>使用移动手动电动工具,外接线绝缘皮破损,应停止使用.</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.096279</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>16</td>\n",
       "      <td>49</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>257</td>\n",
       "      <td>一般</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.496831</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>101</td>\n",
       "      <td>353</td>\n",
       "      <td>消防知识要加强</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.056331</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>257</td>\n",
       "      <td>消防通道有货物摆放 清理不及时</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.817928</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>287</td>\n",
       "      <td>防火门打开状态</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.853796</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>33</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  level_1  level_2  level_3  level_4                     content  label  \\\n",
       "0   0       12       32      170      136  使用移动手动电动工具,外接线绝缘皮破损,应停止使用.    0.0   \n",
       "1   1       12        8       56      257                          一般    1.0   \n",
       "2   2       12        8      101      353                     消防知识要加强    0.0   \n",
       "3   3       12        8       56      257             消防通道有货物摆放 清理不及时    0.0   \n",
       "4   4       12        8       56      287                     防火门打开状态    0.0   \n",
       "\n",
       "   bert_pred  level_1_strlen  level_2_strlen  level_3_strlen  level_4_strlen  \\\n",
       "0  -3.096279              17               7              16              49   \n",
       "1   2.496831              17               7               6              24   \n",
       "2   1.056331              17               7               6              24   \n",
       "3  -3.817928              17               7               6              24   \n",
       "4  -2.853796              17               7               6              33   \n",
       "\n",
       "   content_strlen  \n",
       "0              26  \n",
       "1               2  \n",
       "2               7  \n",
       "3              15  \n",
       "4               7  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the label-encoded levels and the new length features.\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache /tmp/jieba.cache\n",
      "Loading model cost 1.105 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "# Assign the filled column back: fillna(inplace=True) on df['content'] is chained assignment and\n",
    "# does not update df under pandas Copy-on-Write.\n",
    "df['content'] = df['content'].fillna('').astype(str)\n",
    "# Segment with jieba and join tokens with single spaces (the TF-IDF token pattern splits on spaces).\n",
    "df['content_seg'] = df['content'].apply(lambda x: \" \".join(jieba.cut(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split() without an argument skips empty strings, so empty content counts as 0 words and the\n",
    "# whitespace tokens jieba yields for spaces in the text (runs of spaces after joining) are not counted.\n",
    "df['content_word_cnt'] = df['content_seg'].apply(lambda x: len(x.split()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_components = 16\n",
    "svd_seed = 42  # fixed seed so the SVD features are identical across runs\n",
    "\n",
    "# Word-level TF-IDF over the jieba-segmented text: tokens are the space-separated words, and the\n",
    "# \\b anchors keep punctuation-only tokens out of the vocabulary.\n",
    "X = list(df['content_seg'].values)\n",
    "tfv = TfidfVectorizer(ngram_range=(1,1), \n",
    "                      token_pattern=r\"(?u)\\b[^ ]+\\b\",\n",
    "                      max_features=10000)\n",
    "X_tfidf = tfv.fit_transform(X)\n",
    "# TruncatedSVD's default solver is randomized; without random_state the components change every run.\n",
    "svd = TruncatedSVD(n_components=n_components, random_state=svd_seed)\n",
    "X_svd = svd.fit_transform(X_tfidf)\n",
    "\n",
    "for i in range(n_components):\n",
    "    df[f'content_tfidf_{i}'] = X_svd[:, i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>level_1</th>\n",
       "      <th>level_2</th>\n",
       "      <th>level_3</th>\n",
       "      <th>level_4</th>\n",
       "      <th>content</th>\n",
       "      <th>label</th>\n",
       "      <th>bert_pred</th>\n",
       "      <th>level_1_strlen</th>\n",
       "      <th>level_2_strlen</th>\n",
       "      <th>level_3_strlen</th>\n",
       "      <th>level_4_strlen</th>\n",
       "      <th>content_strlen</th>\n",
       "      <th>content_seg</th>\n",
       "      <th>content_word_cnt</th>\n",
       "      <th>content_tfidf_0</th>\n",
       "      <th>content_tfidf_1</th>\n",
       "      <th>content_tfidf_2</th>\n",
       "      <th>content_tfidf_3</th>\n",
       "      <th>content_tfidf_4</th>\n",
       "      <th>content_tfidf_5</th>\n",
       "      <th>content_tfidf_6</th>\n",
       "      <th>content_tfidf_7</th>\n",
       "      <th>content_tfidf_8</th>\n",
       "      <th>content_tfidf_9</th>\n",
       "      <th>content_tfidf_10</th>\n",
       "      <th>content_tfidf_11</th>\n",
       "      <th>content_tfidf_12</th>\n",
       "      <th>content_tfidf_13</th>\n",
       "      <th>content_tfidf_14</th>\n",
       "      <th>content_tfidf_15</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>12</td>\n",
       "      <td>32</td>\n",
       "      <td>170</td>\n",
       "      <td>136</td>\n",
       "      <td>使用移动手动电动工具,外接线绝缘皮破损,应停止使用.</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.096279</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>16</td>\n",
       "      <td>49</td>\n",
       "      <td>26</td>\n",
       "      <td>使用 移动 手动 电动工具 , 外 接线 绝缘 皮 破损 , 应 停止使用 .</td>\n",
       "      <td>14</td>\n",
       "      <td>0.006962</td>\n",
       "      <td>0.004836</td>\n",
       "      <td>0.009002</td>\n",
       "      <td>0.005316</td>\n",
       "      <td>0.001656</td>\n",
       "      <td>0.002731</td>\n",
       "      <td>-0.004912</td>\n",
       "      <td>0.011597</td>\n",
       "      <td>0.011524</td>\n",
       "      <td>-0.001966</td>\n",
       "      <td>0.007598</td>\n",
       "      <td>-0.011693</td>\n",
       "      <td>-0.008630</td>\n",
       "      <td>-0.008527</td>\n",
       "      <td>0.010131</td>\n",
       "      <td>0.003516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>257</td>\n",
       "      <td>一般</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.496831</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>2</td>\n",
       "      <td>一般</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000851</td>\n",
       "      <td>0.001120</td>\n",
       "      <td>-0.000152</td>\n",
       "      <td>0.000703</td>\n",
       "      <td>0.001360</td>\n",
       "      <td>0.001832</td>\n",
       "      <td>-0.000563</td>\n",
       "      <td>-0.002091</td>\n",
       "      <td>0.001959</td>\n",
       "      <td>-0.003571</td>\n",
       "      <td>0.000390</td>\n",
       "      <td>-0.015447</td>\n",
       "      <td>-0.002157</td>\n",
       "      <td>0.012254</td>\n",
       "      <td>-0.006067</td>\n",
       "      <td>0.004736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>101</td>\n",
       "      <td>353</td>\n",
       "      <td>消防知识要加强</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.056331</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>7</td>\n",
       "      <td>消防 知识 要 加强</td>\n",
       "      <td>4</td>\n",
       "      <td>0.011752</td>\n",
       "      <td>0.004001</td>\n",
       "      <td>0.015744</td>\n",
       "      <td>0.005031</td>\n",
       "      <td>-0.002309</td>\n",
       "      <td>0.005014</td>\n",
       "      <td>-0.002317</td>\n",
       "      <td>0.014483</td>\n",
       "      <td>0.023243</td>\n",
       "      <td>-0.001660</td>\n",
       "      <td>0.010149</td>\n",
       "      <td>-0.011126</td>\n",
       "      <td>0.000899</td>\n",
       "      <td>-0.008132</td>\n",
       "      <td>0.006392</td>\n",
       "      <td>0.009070</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>257</td>\n",
       "      <td>消防通道有货物摆放 清理不及时</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.817928</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>24</td>\n",
       "      <td>15</td>\n",
       "      <td>消防通道 有 货物 摆放   清理 不 及时</td>\n",
       "      <td>9</td>\n",
       "      <td>0.240184</td>\n",
       "      <td>-0.047535</td>\n",
       "      <td>-0.055964</td>\n",
       "      <td>0.169164</td>\n",
       "      <td>0.134602</td>\n",
       "      <td>0.178350</td>\n",
       "      <td>-0.103931</td>\n",
       "      <td>0.044703</td>\n",
       "      <td>0.084637</td>\n",
       "      <td>0.275877</td>\n",
       "      <td>0.350776</td>\n",
       "      <td>-0.365059</td>\n",
       "      <td>-0.030264</td>\n",
       "      <td>0.063569</td>\n",
       "      <td>-0.077093</td>\n",
       "      <td>-0.043314</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>287</td>\n",
       "      <td>防火门打开状态</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.853796</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>33</td>\n",
       "      <td>7</td>\n",
       "      <td>防火门 打开 状态</td>\n",
       "      <td>3</td>\n",
       "      <td>0.010451</td>\n",
       "      <td>0.008915</td>\n",
       "      <td>0.010056</td>\n",
       "      <td>0.005501</td>\n",
       "      <td>0.001714</td>\n",
       "      <td>0.012143</td>\n",
       "      <td>0.003139</td>\n",
       "      <td>0.055971</td>\n",
       "      <td>0.036108</td>\n",
       "      <td>-0.037276</td>\n",
       "      <td>0.051451</td>\n",
       "      <td>-0.002573</td>\n",
       "      <td>-0.097110</td>\n",
       "      <td>-0.042557</td>\n",
       "      <td>0.053756</td>\n",
       "      <td>-0.011044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>101</td>\n",
       "      <td>364</td>\n",
       "      <td>防爆柜里面稀释剂，机油费混装</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.080718</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>40</td>\n",
       "      <td>14</td>\n",
       "      <td>防爆 柜 里面 稀释剂 ， 机油 费 混装</td>\n",
       "      <td>8</td>\n",
       "      <td>0.002810</td>\n",
       "      <td>0.001136</td>\n",
       "      <td>0.005998</td>\n",
       "      <td>-0.001037</td>\n",
       "      <td>0.001480</td>\n",
       "      <td>0.007708</td>\n",
       "      <td>-0.010797</td>\n",
       "      <td>0.006677</td>\n",
       "      <td>-0.000775</td>\n",
       "      <td>-0.005653</td>\n",
       "      <td>-0.006180</td>\n",
       "      <td>-0.006884</td>\n",
       "      <td>0.002432</td>\n",
       "      <td>-0.006244</td>\n",
       "      <td>-0.004764</td>\n",
       "      <td>-0.014633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>178</td>\n",
       "      <td>已经整改</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.261704</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>34</td>\n",
       "      <td>4</td>\n",
       "      <td>已经 整改</td>\n",
       "      <td>2</td>\n",
       "      <td>0.186023</td>\n",
       "      <td>0.381582</td>\n",
       "      <td>-0.098095</td>\n",
       "      <td>-0.068713</td>\n",
       "      <td>-0.018996</td>\n",
       "      <td>-0.014279</td>\n",
       "      <td>0.001265</td>\n",
       "      <td>0.009713</td>\n",
       "      <td>-0.016689</td>\n",
       "      <td>-0.012468</td>\n",
       "      <td>0.004466</td>\n",
       "      <td>-0.024171</td>\n",
       "      <td>-0.014421</td>\n",
       "      <td>0.003721</td>\n",
       "      <td>-0.051968</td>\n",
       "      <td>-0.002361</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>56</td>\n",
       "      <td>178</td>\n",
       "      <td>逃生通道有货物阻挡。</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.527178</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>34</td>\n",
       "      <td>10</td>\n",
       "      <td>逃生 通道 有 货物 阻挡 。</td>\n",
       "      <td>6</td>\n",
       "      <td>0.254065</td>\n",
       "      <td>-0.111572</td>\n",
       "      <td>-0.077983</td>\n",
       "      <td>0.104665</td>\n",
       "      <td>0.214997</td>\n",
       "      <td>-0.063296</td>\n",
       "      <td>-0.012981</td>\n",
       "      <td>0.074042</td>\n",
       "      <td>0.053388</td>\n",
       "      <td>0.055140</td>\n",
       "      <td>-0.098229</td>\n",
       "      <td>-0.248198</td>\n",
       "      <td>-0.097865</td>\n",
       "      <td>0.132346</td>\n",
       "      <td>-0.079019</td>\n",
       "      <td>0.035055</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>12</td>\n",
       "      <td>8</td>\n",
       "      <td>101</td>\n",
       "      <td>180</td>\n",
       "      <td>已整改</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.132907</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>28</td>\n",
       "      <td>3</td>\n",
       "      <td>已 整改</td>\n",
       "      <td>2</td>\n",
       "      <td>0.397351</td>\n",
       "      <td>0.856610</td>\n",
       "      <td>-0.240545</td>\n",
       "      <td>-0.164496</td>\n",
       "      <td>-0.048351</td>\n",
       "      <td>-0.055773</td>\n",
       "      <td>0.010368</td>\n",
       "      <td>-0.020039</td>\n",
       "      <td>-0.062835</td>\n",
       "      <td>-0.025351</td>\n",
       "      <td>0.002410</td>\n",
       "      <td>0.019105</td>\n",
       "      <td>0.007217</td>\n",
       "      <td>0.023712</td>\n",
       "      <td>-0.031490</td>\n",
       "      <td>0.009754</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>12</td>\n",
       "      <td>70</td>\n",
       "      <td>7</td>\n",
       "      <td>38</td>\n",
       "      <td>通道黄色线标脱落，已及时重新标好线标。</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.748759</td>\n",
       "      <td>17</td>\n",
       "      <td>7</td>\n",
       "      <td>6</td>\n",
       "      <td>55</td>\n",
       "      <td>19</td>\n",
       "      <td>通道 黄色 线标 脱落 ， 已 及时 重新 标好 线标 。</td>\n",
       "      <td>11</td>\n",
       "      <td>0.101173</td>\n",
       "      <td>0.045701</td>\n",
       "      <td>-0.014822</td>\n",
       "      <td>-0.002153</td>\n",
       "      <td>0.009648</td>\n",
       "      <td>-0.048944</td>\n",
       "      <td>-0.029479</td>\n",
       "      <td>0.017528</td>\n",
       "      <td>0.021423</td>\n",
       "      <td>-0.026241</td>\n",
       "      <td>-0.024783</td>\n",
       "      <td>-0.011687</td>\n",
       "      <td>-0.009744</td>\n",
       "      <td>-0.025672</td>\n",
       "      <td>0.024590</td>\n",
       "      <td>-0.014820</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  level_1  level_2  level_3  level_4                     content  label  \\\n",
       "0   0       12       32      170      136  使用移动手动电动工具,外接线绝缘皮破损,应停止使用.    0.0   \n",
       "1   1       12        8       56      257                          一般    1.0   \n",
       "2   2       12        8      101      353                     消防知识要加强    0.0   \n",
       "3   3       12        8       56      257             消防通道有货物摆放 清理不及时    0.0   \n",
       "4   4       12        8       56      287                     防火门打开状态    0.0   \n",
       "5   5       12        8      101      364              防爆柜里面稀释剂，机油费混装    0.0   \n",
       "6   6       12        8       56      178                        已经整改    1.0   \n",
       "7   7       12        8       56      178                  逃生通道有货物阻挡。    0.0   \n",
       "8   8       12        8      101      180                         已整改    1.0   \n",
       "9   9       12       70        7       38         通道黄色线标脱落，已及时重新标好线标。    0.0   \n",
       "\n",
       "   bert_pred  level_1_strlen  level_2_strlen  level_3_strlen  level_4_strlen  \\\n",
       "0  -3.096279              17               7              16              49   \n",
       "1   2.496831              17               7               6              24   \n",
       "2   1.056331              17               7               6              24   \n",
       "3  -3.817928              17               7               6              24   \n",
       "4  -2.853796              17               7               6              33   \n",
       "5  -3.080718              17               7               6              40   \n",
       "6   3.261704              17               7               6              34   \n",
       "7  -3.527178              17               7               6              34   \n",
       "8   3.132907              17               7               6              28   \n",
       "9  -3.748759              17               7               6              55   \n",
       "\n",
       "   content_strlen                              content_seg  content_word_cnt  \\\n",
       "0              26  使用 移动 手动 电动工具 , 外 接线 绝缘 皮 破损 , 应 停止使用 .                14   \n",
       "1               2                                       一般                 1   \n",
       "2               7                               消防 知识 要 加强                 4   \n",
       "3              15                   消防通道 有 货物 摆放   清理 不 及时                 9   \n",
       "4               7                                防火门 打开 状态                 3   \n",
       "5              14                    防爆 柜 里面 稀释剂 ， 机油 费 混装                 8   \n",
       "6               4                                    已经 整改                 2   \n",
       "7              10                          逃生 通道 有 货物 阻挡 。                 6   \n",
       "8               3                                     已 整改                 2   \n",
       "9              19            通道 黄色 线标 脱落 ， 已 及时 重新 标好 线标 。                11   \n",
       "\n",
       "   content_tfidf_0  content_tfidf_1  content_tfidf_2  content_tfidf_3  \\\n",
       "0         0.006962         0.004836         0.009002         0.005316   \n",
       "1         0.000851         0.001120        -0.000152         0.000703   \n",
       "2         0.011752         0.004001         0.015744         0.005031   \n",
       "3         0.240184        -0.047535        -0.055964         0.169164   \n",
       "4         0.010451         0.008915         0.010056         0.005501   \n",
       "5         0.002810         0.001136         0.005998        -0.001037   \n",
       "6         0.186023         0.381582        -0.098095        -0.068713   \n",
       "7         0.254065        -0.111572        -0.077983         0.104665   \n",
       "8         0.397351         0.856610        -0.240545        -0.164496   \n",
       "9         0.101173         0.045701        -0.014822        -0.002153   \n",
       "\n",
       "   content_tfidf_4  content_tfidf_5  content_tfidf_6  content_tfidf_7  \\\n",
       "0         0.001656         0.002731        -0.004912         0.011597   \n",
       "1         0.001360         0.001832        -0.000563        -0.002091   \n",
       "2        -0.002309         0.005014        -0.002317         0.014483   \n",
       "3         0.134602         0.178350        -0.103931         0.044703   \n",
       "4         0.001714         0.012143         0.003139         0.055971   \n",
       "5         0.001480         0.007708        -0.010797         0.006677   \n",
       "6        -0.018996        -0.014279         0.001265         0.009713   \n",
       "7         0.214997        -0.063296        -0.012981         0.074042   \n",
       "8        -0.048351        -0.055773         0.010368        -0.020039   \n",
       "9         0.009648        -0.048944        -0.029479         0.017528   \n",
       "\n",
       "   content_tfidf_8  content_tfidf_9  content_tfidf_10  content_tfidf_11  \\\n",
       "0         0.011524        -0.001966          0.007598         -0.011693   \n",
       "1         0.001959        -0.003571          0.000390         -0.015447   \n",
       "2         0.023243        -0.001660          0.010149         -0.011126   \n",
       "3         0.084637         0.275877          0.350776         -0.365059   \n",
       "4         0.036108        -0.037276          0.051451         -0.002573   \n",
       "5        -0.000775        -0.005653         -0.006180         -0.006884   \n",
       "6        -0.016689        -0.012468          0.004466         -0.024171   \n",
       "7         0.053388         0.055140         -0.098229         -0.248198   \n",
       "8        -0.062835        -0.025351          0.002410          0.019105   \n",
       "9         0.021423        -0.026241         -0.024783         -0.011687   \n",
       "\n",
       "   content_tfidf_12  content_tfidf_13  content_tfidf_14  content_tfidf_15  \n",
       "0         -0.008630         -0.008527          0.010131          0.003516  \n",
       "1         -0.002157          0.012254         -0.006067          0.004736  \n",
       "2          0.000899         -0.008132          0.006392          0.009070  \n",
       "3         -0.030264          0.063569         -0.077093         -0.043314  \n",
       "4         -0.097110         -0.042557          0.053756         -0.011044  \n",
       "5          0.002432         -0.006244         -0.004764         -0.014633  \n",
       "6         -0.014421          0.003721         -0.051968         -0.002361  \n",
       "7         -0.097865          0.132346         -0.079019          0.035055  \n",
       "8          0.007217          0.023712         -0.031490          0.009754  \n",
       "9         -0.009744         -0.025672          0.024590         -0.014820  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first ten rows of the feature frame (raw text, length, word-count and TF-IDF columns).\n",
    "df.iloc[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(12000, 29) (18000, 29)\n"
     ]
    }
   ],
   "source": [
    "# Drop the raw text columns WITHOUT mutating `df`, so this cell can be re-run\n",
    "# (an in-place drop raises KeyError the second time because the columns are gone).\n",
    "# Labeled rows form the training set, unlabeled rows the test set. `.copy()` makes\n",
    "# both frames independent of `df`, so later column assignments are not made on slices.\n",
    "text_cols = ['content', 'content_seg']\n",
    "features = df.drop(columns=text_cols)\n",
    "train = features[features['label'].notna()].copy()\n",
    "test = features[features['label'].isna()].copy()\n",
    "\n",
    "print(train.shape, test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Fold_1 Training ================================\n",
      "\n",
      "[LightGBM] [Warning] feature_fraction is set=0.8, colsample_bytree=1.0 will be ignored. Current value: feature_fraction=0.8\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[12]\ttrain's auc: 0.994687\tvalid's auc: 0.99433\n",
      "\n",
      "Fold_2 Training ================================\n",
      "\n",
      "[LightGBM] [Warning] feature_fraction is set=0.8, colsample_bytree=1.0 will be ignored. Current value: feature_fraction=0.8\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[16]\ttrain's auc: 0.994427\tvalid's auc: 0.996004\n",
      "\n",
      "Fold_3 Training ================================\n",
      "\n",
      "[LightGBM] [Warning] feature_fraction is set=0.8, colsample_bytree=1.0 will be ignored. Current value: feature_fraction=0.8\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "[100]\ttrain's auc: 0.998878\tvalid's auc: 0.990012\n",
      "Early stopping, best iteration is:\n",
      "[79]\ttrain's auc: 0.998387\tvalid's auc: 0.990388\n",
      "\n",
      "Fold_4 Training ================================\n",
      "\n",
      "[LightGBM] [Warning] feature_fraction is set=0.8, colsample_bytree=1.0 will be ignored. Current value: feature_fraction=0.8\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "[100]\ttrain's auc: 0.998891\tvalid's auc: 0.987676\n",
      "Early stopping, best iteration is:\n",
      "[106]\ttrain's auc: 0.998994\tvalid's auc: 0.987761\n",
      "\n",
      "Fold_5 Training ================================\n",
      "\n",
      "[LightGBM] [Warning] feature_fraction is set=0.8, colsample_bytree=1.0 will be ignored. Current value: feature_fraction=0.8\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "[100]\ttrain's auc: 0.998777\tvalid's auc: 0.991013\n",
      "Early stopping, best iteration is:\n",
      "[84]\ttrain's auc: 0.998464\tvalid's auc: 0.992256\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>column</th>\n",
       "      <th>importance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>bert_pred</td>\n",
       "      <td>135.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>content_tfidf_13</td>\n",
       "      <td>36.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>content_tfidf_10</td>\n",
       "      <td>35.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>content_tfidf_0</td>\n",
       "      <td>33.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>content_strlen</td>\n",
       "      <td>33.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>content_tfidf_3</td>\n",
       "      <td>32.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>content_tfidf_11</td>\n",
       "      <td>29.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>content_tfidf_6</td>\n",
       "      <td>27.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>content_tfidf_7</td>\n",
       "      <td>27.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>content_tfidf_15</td>\n",
       "      <td>26.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>level_4</td>\n",
       "      <td>25.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>content_tfidf_14</td>\n",
       "      <td>23.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>content_tfidf_9</td>\n",
       "      <td>23.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>content_tfidf_5</td>\n",
       "      <td>22.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>content_tfidf_8</td>\n",
       "      <td>20.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>content_tfidf_1</td>\n",
       "      <td>20.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>content_tfidf_4</td>\n",
       "      <td>20.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>level_4_strlen</td>\n",
       "      <td>18.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>content_tfidf_2</td>\n",
       "      <td>16.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>content_tfidf_12</td>\n",
       "      <td>15.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>content_word_cnt</td>\n",
       "      <td>12.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>level_3</td>\n",
       "      <td>11.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>level_2</td>\n",
       "      <td>7.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>level_2_strlen</td>\n",
       "      <td>5.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>level_1</td>\n",
       "      <td>5.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>level_1_strlen</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>level_3_strlen</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              column  importance\n",
       "0          bert_pred       135.2\n",
       "1   content_tfidf_13        36.2\n",
       "2   content_tfidf_10        35.8\n",
       "3    content_tfidf_0        33.6\n",
       "4     content_strlen        33.0\n",
       "5    content_tfidf_3        32.8\n",
       "6   content_tfidf_11        29.4\n",
       "7    content_tfidf_6        27.4\n",
       "8    content_tfidf_7        27.0\n",
       "9   content_tfidf_15        26.2\n",
       "10           level_4        25.4\n",
       "11  content_tfidf_14        23.6\n",
       "12   content_tfidf_9        23.6\n",
       "13   content_tfidf_5        22.8\n",
       "14   content_tfidf_8        20.4\n",
       "15   content_tfidf_1        20.4\n",
       "16   content_tfidf_4        20.2\n",
       "17    level_4_strlen        18.8\n",
       "18   content_tfidf_2        16.4\n",
       "19  content_tfidf_12        15.6\n",
       "20  content_word_cnt        12.2\n",
       "21           level_3        11.6\n",
       "22           level_2         7.2\n",
       "23    level_2_strlen         5.8\n",
       "24           level_1         5.4\n",
       "25    level_1_strlen         2.0\n",
       "26    level_3_strlen         2.0"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ycol = 'label'\n",
    "# Every column except the target and the row id is used as a model feature.\n",
    "feature_names = [col for col in train.columns if col not in (ycol, 'id')]\n",
    "\n",
    "model = lgb.LGBMClassifier(objective='binary',\n",
    "                           boosting_type='gbdt',\n",
    "                           tree_learner='serial',\n",
    "                           num_leaves=32,\n",
    "                           max_depth=4,\n",
    "                           learning_rate=0.1,\n",
    "                           n_estimators=1000,\n",
    "                           # NOTE(review): LightGBM only applies `subsample` (bagging_fraction)\n",
    "                           # when `subsample_freq` (bagging_freq) > 0. It is not set here, so\n",
    "                           # row bagging is effectively off; set subsample_freq=1 to enable it\n",
    "                           # (this will change the CV scores).\n",
    "                           subsample=0.8,\n",
    "                           # sklearn-API name for feature_fraction; passing the alias directly\n",
    "                           # made LightGBM warn on every fold that colsample_bytree=1.0 is ignored.\n",
    "                           colsample_bytree=0.8,\n",
    "                           reg_alpha=2,\n",
    "                           reg_lambda=3,\n",
    "                           random_state=2021,\n",
    "                           is_unbalance=True,\n",
    "                           metric='auc')\n",
    "\n",
    "\n",
    "oof = []\n",
    "# Explicit copy: assigning a new column to a column subset of `test` otherwise triggers\n",
    "# pandas' SettingWithCopyWarning (hidden here by the global warnings filter).\n",
    "prediction = test[['id']].copy()\n",
    "prediction[ycol] = 0.0\n",
    "df_importance_list = []\n",
    "\n",
    "kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=2021)\n",
    "for fold_id, (trn_idx, val_idx) in enumerate(kfold.split(train[feature_names], train[ycol])):\n",
    "    # Slice each fold once and reuse it for features, target and the OOF frame.\n",
    "    trn_part = train.iloc[trn_idx]\n",
    "    val_part = train.iloc[val_idx]\n",
    "    X_train, Y_train = trn_part[feature_names], trn_part[ycol]\n",
    "    X_val, Y_val = val_part[feature_names], val_part[ycol]\n",
    "\n",
    "    print('\\nFold_{} Training ================================\\n'.format(fold_id+1))\n",
    "\n",
    "    # NOTE(review): the `verbose` / `early_stopping_rounds` fit() kwargs were removed in\n",
    "    # LightGBM 4.x; there use callbacks=[lgb.early_stopping(50), lgb.log_evaluation(100)].\n",
    "    lgb_model = model.fit(X_train,\n",
    "                          Y_train,\n",
    "                          eval_names=['train', 'valid'],\n",
    "                          eval_set=[(X_train, Y_train), (X_val, Y_val)],\n",
    "                          verbose=100,\n",
    "                          eval_metric='auc',\n",
    "                          early_stopping_rounds=50)\n",
    "\n",
    "    # Out-of-fold positive-class probabilities, used later to tune the decision threshold.\n",
    "    pred_val = lgb_model.predict_proba(\n",
    "        X_val, num_iteration=lgb_model.best_iteration_)\n",
    "    df_oof = val_part[['id', ycol]].copy()\n",
    "    df_oof['pred'] = pred_val[:, 1]\n",
    "    oof.append(df_oof)\n",
    "\n",
    "    # Test-set positive-class probabilities, averaged over the folds.\n",
    "    pred_test = lgb_model.predict_proba(\n",
    "        test[feature_names], num_iteration=lgb_model.best_iteration_)\n",
    "    prediction[ycol] += pred_test[:, 1] / kfold.n_splits\n",
    "\n",
    "    df_importance_list.append(pd.DataFrame({\n",
    "        'column': feature_names,\n",
    "        'importance': lgb_model.feature_importances_,\n",
    "    }))\n",
    "\n",
    "    del lgb_model, pred_val, pred_test, X_train, Y_train, X_val, Y_val, trn_part, val_part\n",
    "    gc.collect()\n",
    "\n",
    "# Mean feature importance across folds, highest first.\n",
    "df_importance = pd.concat(df_importance_list)\n",
    "df_importance = df_importance.groupby(['column'])['importance'].agg(\n",
    "    'mean').sort_values(ascending=False).reset_index()\n",
    "df_importance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1 f1_score: 0.7721557804354493\n",
      "0.15000000000000002 f1_score: 0.81060116354234\n",
      "0.20000000000000004 f1_score: 0.8255544521681561\n",
      "0.25000000000000006 f1_score: 0.8336686787391012\n",
      "0.30000000000000004 f1_score: 0.8449744463373083\n",
      "0.3500000000000001 f1_score: 0.8494845360824743\n",
      "0.40000000000000013 f1_score: 0.8579367836054185\n",
      "0.45000000000000007 f1_score: 0.8634764250527798\n",
      "0.5000000000000001 f1_score: 0.873481057898499\n",
      "0.5500000000000002 f1_score: 0.8797675263349073\n",
      "0.6000000000000002 f1_score: 0.8918617614269789\n",
      "0.6500000000000001 f1_score: 0.89913109180204\n",
      "0.7000000000000002 f1_score: 0.9004594180704442\n",
      "0.7500000000000002 f1_score: 0.8566735112936346\n",
      "0.8000000000000002 f1_score: 0.8020969855832242\n",
      "0.8500000000000002 f1_score: 0.6722276741903829\n",
      "0.9000000000000002 f1_score: 0.6676602086438153\n",
      "0.9500000000000003 f1_score: 0.6596173212487412\n"
     ]
    }
   ],
   "source": [
    "# Sweep decision thresholds over the out-of-fold probabilities and keep the best F1.\n",
    "# Thresholds are rounded because np.arange accumulates float error (0.15000000000000002, ...).\n",
    "# bst starts below any attainable F1 so i_bst is always one of the swept thresholds, even if\n",
    "# every threshold scores 0 (an i_bst of 0 would label every test row positive).\n",
    "i_bst = 0\n",
    "bst = -1.0\n",
    "df_oof = pd.concat(oof)\n",
    "for threshold in np.round(np.arange(0.1, 1.0, 0.05), 2):\n",
    "    # Vectorized comparison instead of a per-row lambda.\n",
    "    df_oof['pred_label'] = (df_oof['pred'] >= threshold).astype(int)\n",
    "    score = f1_score(df_oof['label'], df_oof['pred_label'])\n",
    "    print(threshold, 'f1_score:', score)\n",
    "    if score > bst:\n",
    "        i_bst = threshold\n",
    "        bst = score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    16074\n",
       "1     1926\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep the fold-averaged probabilities in `prob` (created once) and binarize them into\n",
    "# `label` with the tuned threshold. Reading from `prob` instead of overwriting the\n",
    "# probabilities in place lets this cell be re-run with a different i_bst.\n",
    "if 'prob' not in prediction.columns:\n",
    "    prediction['prob'] = prediction['label']\n",
    "prediction['label'] = (prediction['prob'] >= i_bst).astype(int)\n",
    "prediction['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the submission file (id + binary label); the filename records the best OOF F1.\n",
    "submission_path = f'submission_{bst}.csv'\n",
    "prediction.loc[:, ['id', 'label']].to_csv(submission_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
